"""
[1] Mastering Diverse Domains through World Models - 2023
D. Hafner, J. Pasukonis, J. Ba, T. Lillicrap
https://arxiv.org/pdf/2301.04104v1.pdf

[2] Mastering Atari with Discrete World Models - 2021
D. Hafner, T. Lillicrap, M. Norouzi, J. Ba
https://arxiv.org/pdf/2010.02193.pdf
"""
from ray.rllib.utils.framework import try_import_tf

_, tf, _ = try_import_tf()


class RewardPredictorLayer(tf.keras.layers.Layer):
    """A layer outputting reward predictions using K bins and two-hot encoding.

    This layer is used in two models in DreamerV3: The reward predictor of the world
    model and the value function. K is 255 by default (see [1]) and doesn't change
    with the model size.

    With the default bounds, possible predicted reward/values range from
    symexp(-20.0) to symexp(20.0), which should cover any possible environment.
    Outputs of this layer are generated by computing logits via a single linear
    layer, then interpreting the softmax'd logits (probs) as weights for a weighted
    average of the different possible (binned) reward values.

    Note that the binned values - and thus the weighted average returned by
    `call()` - live in the same (symlog'd) space as `lower_bound` and `upper_bound`.
    No symexp is applied inside this layer.
    """

    def __init__(
        self,
        *,
        num_buckets: int = 255,
        lower_bound: float = -20.0,
        upper_bound: float = 20.0,
        trainable: bool = True,
    ) -> None:
        """Initializes a RewardPredictorLayer instance.

        Args:
            num_buckets: The number of evenly spaced possible (symlog'd) outcomes
                and thereby also the number of output units (logits) of the linear
                layer. Both bounds are included in the outcomes:
                o[0] ... o[1] ... o[2] ... o[num_buckets - 1]
                o=outcomes
                lower_bound=o[0]
                upper_bound=o[num_buckets - 1]
            lower_bound: The symlog'd lower bound for a possible reward value.
                Note that a value of -20.0 here already allows individual (actual env)
                rewards to be as low as -400M. The possible outcomes are spread evenly
                between `lower_bound` and `upper_bound`.
            upper_bound: The symlog'd upper bound for a possible reward value.
                Note that a value of +20.0 here already allows individual (actual env)
                rewards to be as high as 400M. The possible outcomes are spread evenly
                between `lower_bound` and `upper_bound`.
            trainable: Whether the weights of the linear layer computing the logits
                are trainable. This flag is forwarded to the underlying
                `tf.keras.layers.Dense` layer only.
        """
        # Initialize the keras base class first, before setting any attributes.
        super().__init__(name=f"reward_layer_{num_buckets}buckets")

        self.num_buckets = num_buckets
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        self.reward_buckets_layer = tf.keras.layers.Dense(
            units=self.num_buckets,
            activation=None,
            # From [1]:
            # "We further noticed that the randomly initialized reward predictor and
            # critic networks at the start of training can result in large predicted
            # rewards that can delay the onset of learning. We initialize the output
            # weights of the reward predictor and critic to zeros, which effectively
            # alleviates the problem and accelerates early learning."
            kernel_initializer="zeros",
            bias_initializer="zeros",  # zero-bias is default anyways
            trainable=trainable,
        )

    def call(self, inputs):
        """Computes the expected reward using `num_buckets` evenly spaced outcomes.

        Args:
            inputs: The input tensor for the layer, which computes the reward bucket
                weights (logits). [B, dim].

        Returns:
            A tuple consisting of the expected rewards ([B]) and the float32 logits
            ([B, `num_buckets`]) from which these were computed. The expected rewards
            are the softmax(logits)-weighted average over the possible outcomes and
            thus lie between `lower_bound` and `upper_bound`. The logits may be used
            by the caller to parameterize a distribution over the possible outcomes
            (e.g. a tfp `FiniteDiscrete`), whose `probs` are the individual bucket
            probabilities.
        """
        # Compute the `num_buckets` weights.
        assert (
            len(inputs.shape) == 2
        ), f"`inputs` must be of shape [B, dim], but has shape {inputs.shape}!"
        # Always continue in float32, regardless of the Dense layer's compute dtype.
        logits = tf.cast(self.reward_buckets_layer(inputs), tf.float32)
        # out=[B, `num_buckets`]

        # Compute the expected(!) reward using the formula:
        # `softmax(Linear(x))` [vectordot] `possible_outcomes`, where
        # `possible_outcomes` are the `num_buckets` evenly spaced values between
        # (and including) `lower_bound` and `upper_bound`.
        # [2]: "The mean of the reward predictor pφ(ˆrt | zˆt) is used as reward
        # sequence rˆ1:H."
        probs = tf.nn.softmax(logits)
        # Pin the outcomes to float32 (same as `probs`). Otherwise, their dtype would
        # depend on the python type of the bounds (e.g. integer bounds) and the
        # multiplication below could fail with a dtype mismatch.
        possible_outcomes = tf.linspace(
            tf.cast(self.lower_bound, tf.float32),
            tf.cast(self.upper_bound, tf.float32),
            self.num_buckets,
        )
        # probs=[B, `num_buckets`]
        # possible_outcomes=[`num_buckets`] (broadcast over B below)

        # Simple vector dot product (over last dim) to get the mean reward
        # weighted sum, where all weights sum to 1.0.
        expected_rewards = tf.reduce_sum(probs * possible_outcomes, axis=-1)
        # expected_rewards=[B]

        return expected_rewards, logits
